# inspired by sqlalchemy-utils (https://github.com/kvesteri/sqlalchemy-utils)
import abc
import base64
from enum import Enum
from typing import Any, Callable, Optional, TYPE_CHECKING, Type, Union

import sqlalchemy.types as types
from pydantic.utils import lenient_issubclass
from sqlalchemy.engine.default import DefaultDialect

import ormar  # noqa: I100, I202
from ormar import ModelDefinitionError  # noqa: I202, I100

cryptography = None
try:  # pragma: nocover
    import cryptography  # type: ignore
    from cryptography.fernet import Fernet
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
except ImportError:  # pragma: nocover
    pass

if TYPE_CHECKING:  # pragma: nocover
    from ormar import BaseField


class EncryptBackend(abc.ABC):
    def _refresh(self, key: Union[str, bytes]) -> None:
        if isinstance(key, str):
            key = key.encode()
        digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
        digest.update(key)
        engine_key = digest.finalize()
        self._initialize_backend(engine_key)

    @abc.abstractmethod
    def _initialize_backend(self, secret_key: bytes) -> None:  # pragma: nocover
        pass

    @abc.abstractmethod
    def encrypt(self, value: Any) -> str:  # pragma: nocover
        pass

    @abc.abstractmethod
    def decrypt(self, value: Any) -> str:  # pragma: nocover
        pass


class HashBackend(EncryptBackend):
    """
    One-way hashing - in example for passwords, no way to decrypt the value!
    """

    def _initialize_backend(self, secret_key: bytes) -> None:
        self.secret_key = base64.urlsafe_b64encode(secret_key)

    def encrypt(self, value: Any) -> str:
        if not isinstance(value, str):  # pragma: nocover
            value = repr(value)
        value = value.encode()
        digest = hashes.Hash(hashes.SHA512(), backend=default_backend())
        digest.update(self.secret_key)
        digest.update(value)
        hashed_value = digest.finalize()
        return hashed_value.hex()

    def decrypt(self, value: Any) -> str:
        if not isinstance(value, str):  # pragma: nocover
            value = str(value)
        return value


class FernetBackend(EncryptBackend):
    """
    Two-way encryption, data stored in db are encrypted but decrypted during query.
    """

    def _initialize_backend(self, secret_key: bytes) -> None:
        self.secret_key = base64.urlsafe_b64encode(secret_key)
        self.fernet = Fernet(self.secret_key)

    def encrypt(self, value: Any) -> str:
        if not isinstance(value, str):
            value = repr(value)
        value = value.encode()
        encrypted = self.fernet.encrypt(value)
        return encrypted.decode("utf-8")

    def decrypt(self, value: Any) -> str:
        if not isinstance(value, str):  # pragma: nocover
            value = str(value)
        decrypted: Union[str, bytes] = self.fernet.decrypt(value.encode())
        if not isinstance(decrypted, str):
            decrypted = decrypted.decode("utf-8")
        return decrypted


class EncryptBackends(Enum):
    NONE = 0
    FERNET = 1
    HASH = 2
    CUSTOM = 3


BACKENDS_MAP = {
    EncryptBackends.FERNET: FernetBackend,
    EncryptBackends.HASH: HashBackend,
}


class EncryptedString(types.TypeDecorator):
    """
    Used to store encrypted values in a database
    """

    impl = types.TypeEngine

    def __init__(
        self,
        encrypt_secret: Union[str, Callable],
        encrypt_backend: EncryptBackends = EncryptBackends.FERNET,
        encrypt_custom_backend: Type[EncryptBackend] = None,
        **kwargs: Any,
    ) -> None:
        _field_type = kwargs.pop("_field_type")
        super().__init__()
        if not cryptography:  # pragma: nocover
            raise ModelDefinitionError(
                "In order to encrypt a column 'cryptography' is required!"
            )
        backend = BACKENDS_MAP.get(encrypt_backend, encrypt_custom_backend)
        if not backend or not lenient_issubclass(backend, EncryptBackend):
            raise ModelDefinitionError("Wrong or no encrypt backend provided!")

        self.backend: EncryptBackend = backend()
        self._field_type: "BaseField" = _field_type
        self._underlying_type: Any = _field_type.column_type
        self._key: Union[str, Callable] = encrypt_secret
        type_ = self._field_type.__type__
        if type_ is None:  # pragma: nocover
            raise ModelDefinitionError(
                f"Improperly configured field " f"{self._field_type.name}"
            )
        self.type_: Any = type_

    def __repr__(self) -> str:  # pragma: nocover
        return "TEXT()"

    def load_dialect_impl(self, dialect: DefaultDialect) -> Any:
        return dialect.type_descriptor(types.TEXT())

    def _refresh(self) -> None:
        key = self._key() if callable(self._key) else self._key
        self.backend._refresh(key)

    def process_bind_param(self, value: Any, dialect: DefaultDialect) -> Optional[str]:
        if value is None:
            return value
        self._refresh()
        try:
            value = self._underlying_type.process_bind_param(value, dialect)
        except AttributeError:
            encoder = ormar.ENCODERS_MAP.get(self.type_, None)
            if encoder:
                value = encoder(value)  # type: ignore

        encrypted_value = self.backend.encrypt(value)
        return encrypted_value

    def process_result_value(self, value: Any, dialect: DefaultDialect) -> Any:
        if value is None:
            return value
        self._refresh()
        decrypted_value = self.backend.decrypt(value)
        try:
            return self._underlying_type.process_result_value(decrypted_value, dialect)
        except AttributeError:
            decoder = ormar.DECODERS_MAP.get(self.type_, None)
            if decoder:
                return decoder(decrypted_value)  # type: ignore

            return self._field_type.__type__(decrypted_value)  # type: ignore
